from models.base import (
                         furnet
                         )


def get_network(network, depth, dataset, use_bn=True):
    """Build and return the model identified by ``network``.

    Args:
        network: Architecture name. Only ``'furnet'`` is supported.
        depth: Accepted for interface compatibility; currently ignored.
        dataset: Forwarded to the architecture constructor.
        use_bn: Accepted for interface compatibility; currently ignored.

    Returns:
        The model returned by ``furnet(dataset)``.

    Raises:
        NotImplementedError: If ``network`` is not a supported name.
    """
    if network == 'furnet':
        return furnet(dataset)
    # Format rather than concatenate so that a non-str ``network`` (e.g. None)
    # still produces the intended NotImplementedError instead of a TypeError.
    raise NotImplementedError(f'Network unsupported {network}')


def stablize_bn(net, trainloader, device='cuda'):
    """Iterate over the dataset to stabilize the BatchNorm statistics.

    Switches ``net`` to training mode and runs a forward pass on every batch
    so that BatchNorm layers update their running statistics. Labels are
    ignored and no parameters are optimized. ``net`` is left in training mode;
    callers should switch it back with ``net.eval()`` if needed.

    Args:
        net: The module whose BatchNorm statistics should be refreshed.
        trainloader: Iterable yielding ``(inputs, target)`` pairs; only
            ``inputs`` is used.
        device: Device the inputs are moved to before each forward pass.
    """
    import torch

    net.train()
    # BatchNorm updates its running buffers during forward in train mode even
    # when autograd is disabled; no_grad avoids building a graph per batch.
    with torch.no_grad():
        for inputs, _ in trainloader:
            net(inputs.to(device))
